{% extends 'base.exported.class' %}

{% block content %}
import com.google.gson.Gson;

import java.io.File;
import java.io.FileNotFoundException;

import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Scanner;


class {{ class_name }} {

    private class Classifier {
        private double[][] X;
        private int[] y;
        private int k;
        private int n;
        private double power;
    }

    private class Neighbor {
        Integer y;
        Double dist;

        private Neighbor(int y, double dist) {
            this.y = y;
            this.dist = dist;
        }
    }

    /** Model parameters read from the JSON model file; set once in the constructor. */
    private final Classifier clf;

    /**
     * Loads the model parameters from a JSON file.
     *
     * @param file path to the JSON model file
     * @throws FileNotFoundException if {@code file} cannot be opened for reading
     * @throws java.util.NoSuchElementException if the file contains no content
     */
    public {{ class_name }}(String file) throws FileNotFoundException {
        String jsonStr;
        // try-with-resources closes the Scanner (and the underlying file handle)
        // even if reading fails; JSON is read as UTF-8 instead of the platform default.
        try (Scanner scanner = new Scanner(new File(file), "UTF-8")) {
            // "\\Z" matches end of input, so next() yields the whole file as one token.
            jsonStr = scanner.useDelimiter("\\Z").next();
        }
        this.clf = new Gson().fromJson(jsonStr, Classifier.class);
    }

    /**
     * Returns the position of the largest value in {@code nums}.
     *
     * <p>Ties resolve to the lowest index because only a strictly greater
     * value replaces the current best; an empty array yields {@code 0}.
     *
     * @param nums values to scan
     * @return index of the first maximum, or {@code 0} for an empty array
     */
    private int findMax(double[] nums) {
        int best = 0;
        int pos = 1;
        while (pos < nums.length) {
            if (nums[pos] > nums[best]) {
                best = pos;
            }
            pos++;
        }
        return best;
    }

    /**
     * Computes the distance of order {@code q} between two vectors.
     *
     * <p>{@code q == 1} sums the absolute differences, {@code q == 2} takes the
     * square root of the summed squares, {@code q == +infinity} takes the
     * largest absolute difference, and any other {@code q} uses the general
     * power form {@code (sum |a_i - b_i|^q)^(1/q)}.
     *
     * @param temp first vector; its length determines how many components are compared
     * @param cand second vector; must have at least {@code temp.length} components
     * @param q    order of the distance
     * @return the distance between {@code temp} and {@code cand}
     */
    private static double compute(double[] temp, double[] cand, double q) {
        final int size = temp.length;
        double acc = 0.;

        if (q == 1) {
            for (int j = 0; j < size; j++) {
                acc += Math.abs(temp[j] - cand[j]);
            }
            return acc;
        }

        if (q == 2) {
            for (int j = 0; j < size; j++) {
                double delta = Math.abs(temp[j] - cand[j]);
                acc += delta * delta;
            }
            return Math.sqrt(acc);
        }

        if (q == Double.POSITIVE_INFINITY) {
            for (int j = 0; j < size; j++) {
                double delta = Math.abs(temp[j] - cand[j]);
                // Explicit comparison (not Math.max) so a NaN delta is ignored.
                if (delta > acc) {
                    acc = delta;
                }
            }
            return acc;
        }

        for (int j = 0; j < size; j++) {
            acc += Math.pow(Math.abs(temp[j] - cand[j]), q);
        }
        return Math.pow(acc, 1. / q);
    }

    /**
     * Predicts the class index for a feature vector.
     *
     * @param features feature vector to classify
     * @return index of the largest entry in the array returned by
     *         {@code predictProba(features)}
     */
    public int predict(double[] features) {
        return findMax(this.predictProba(features));
    }

    public double[] predictProba(double[] features) {
        int classIdx = 0;
        double[] classProbas = new double[this.clf.n];
        double dist;
        if (this.clf.k == 1) {
            double minDist = Double.POSITIVE_INFINITY;
            for (int i = 0; i < this.clf.y.length; i++) {
                dist = {{ class_name }}.compute(this.clf.X[i], features, this.clf.power);
                if (dist <= minDist) {
                    minDist = dist;
                    classIdx = this.clf.y[i];
                }
            }
            classProbas[classIdx] = 1;
        } else {
            ArrayList<Neighbor> dists = new ArrayList<Neighbor>();
            for (int i = 0; i < this.clf.y.length; i++) {
                dist = {{ class_name }}.compute(this.clf.X[i], features, this.clf.power);
                dists.add(new Neighbor(clf.y[i], dist));
            }
            Collections.sort(dists, new Comparator<Neighbor>() {
                @Override
                public int compare(Neighbor n1, Neighbor n2) {
                    return n1.dist.compareTo(n2.dist);
                }
            });
            for (Neighbor neighbor : dists.subList(0, this.clf.k)) {
                classProbas[neighbor.y]++;
            }
            for (int i = 0; i < this.clf.n; i++) {
                classProbas[i] /= this.clf.k;
            }
        }
        return classProbas;
    }

    /**
     * Command-line entry point: loads the model and classifies one sample.
     *
     * @param args {@code args[0]} is the path to the JSON model file; every
     *             following argument is one feature value parsed as a double
     * @throws FileNotFoundException    if the model file cannot be opened
     * @throws IllegalArgumentException if no model file path is given
     */
    public static void main(String[] args) throws FileNotFoundException {

        // Without args[0] the feature array below would be sized -1 and fail
        // with an opaque NegativeArraySizeException; report the usage instead.
        if (args.length < 1) {
            throw new IllegalArgumentException(
                    "Usage: {{ class_name }} <model_data.json> [<feature> ...]");
        }

        // Features:
        double[] features = new double[args.length - 1];
        for (int i = 1, l = args.length; i < l; i++) {
            features[i - 1] = Double.parseDouble(args[i]);
        }

        // Model data:
        String modelData = args[0];

        // Estimator:
        {{ class_name }} clf = new {{ class_name }}(modelData);

        {% if is_test or to_json %}
        // Get JSON:
        int prediction = clf.predict(features);
        double[] probabilities = clf.predictProba(features);
        // Arrays.toString already renders "[a, b, ...]", which is a valid JSON array.
        System.out.println("{\"predict\": " + prediction + ", \"predict_proba\": " + Arrays.toString(probabilities) + "}");
        {% else %}
        // Get class prediction:
        int prediction = clf.predict(features);
        System.out.println("Predicted class: #" + String.valueOf(prediction));

        // Get class probabilities:
        double[] probabilities = clf.predictProba(features);
        for (int i = 0; i < probabilities.length; i++) {
            System.out.print(String.valueOf(probabilities[i]));
            if (i != probabilities.length - 1) {
                System.out.print(",");
            }
        }
        {% endif %}

    }

}
{% endblock %}